package com.wisedu.mobile.core.ratelimit;

import com.wisedu.mobile.common.domain.SysMenu;
import com.wisedu.mobile.common.service.ISysMenuService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.ModelAndView;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

/**
 * Rate-limiting interceptor.
 *
 * <p>For every request handled by a {@link HandlerMethod}, the best-matching URL pattern is looked
 * up in the map returned by {@link ISysMenuService#selectLimiterMenu()}. When a {@link SysMenu}
 * rule is found, a Redis-backed limiter is obtained for that rule and the request is rejected when
 * no permit can be acquired.
 *
 * <p>Conventions described by the original author (they are NOT enforced by this class; they only
 * matter for the SpEL expression stored in each rule):
 * <ul>
 *   <li>For IP-based limiting the HTTP request must carry the client IP in a header. Supported IP
 *       headers: "X-Forwarded-For", "X-Real-IP", "Proxy-Client-IP", "WL-Proxy-Client-IP",
 *       "HTTP_CLIENT_IP", "HTTP_X_FORWARDED_FOR".</li>
 *   <li>When Nginx is the dispatching server and forwards the real client request to a downstream
 *       micro-service, it must pass the real client IP along in a header. Add the following to the
 *       Nginx server configuration:
 *       {@code proxy_set_header    X-Real-IP        $remote_addr;}</li>
 *   <li>For user-based limiting, the header carrying the user token is named {@code UserToken}.</li>
 * </ul>
 *
 * @author Fazhan.Qiao
 * @since 2019/5/28
 */
@Component
public class RateLimitCheckInterceptor implements HandlerInterceptor {
    /** First segment of every Redis key built by this interceptor. */
    private static final String REDIS_KEY_PREFIX = "#RL";

    /** Shared parser; avoids allocating a new parser for every request. */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    /**
     * Parsed expressions keyed by their source text, so that a rule's expression is parsed once
     * instead of on every request.
     */
    private final Map<String, Expression> expressionCache = new ConcurrentHashMap<>();

    @Autowired
    private RedisRateLimiterFactory redisRateLimiterFactory;

    @Value("${spring.application.name}")
    private String applicationName;

    @Autowired
    private ISysMenuService menuService;


    /**
     * Applies the rate-limit rule configured for the matched URL pattern, if any.
     *
     * @param request  current request; its best-matching pattern attribute selects the rule
     * @param response current response; written to only when the request is rejected
     * @param handler  chosen handler; anything other than a {@link HandlerMethod} is let through
     * @return {@code true} to continue the handler chain, {@code false} when the limit is exceeded
     * @throws Exception if the rule is malformed (non-numeric permits, unknown time unit, invalid
     *                   expression) or the limiter / response writer fails
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        Map<String, SysMenu> menuMap = menuService.selectLimiterMenu();
        Object matchedPattern = request.getAttribute(HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE);
        // No rules at all, or no pattern recorded for this request: nothing to enforce.
        if (menuMap == null || matchedPattern == null) {
            return true;
        }
        SysMenu sysMenu = menuMap.get(matchedPattern.toString());
        if (sysMenu == null) {
            return true;
        }

        // A missing or empty expression means the limit is shared by all callers of the URL.
        String baseExp = sysMenu.getBaseExp();
        String baseVal = (baseExp == null || baseExp.isEmpty()) ? "" : eval(baseExp, request);

        int permits = Integer.parseInt(sysMenu.getPermits());
        TimeUnit timeUnit = TimeUnit.valueOf(sysMenu.getTimeUnit());
        String rateLimiterKey = REDIS_KEY_PREFIX + ":" + applicationName + ":" + sysMenu.getUrl() + ":" + baseVal;
        RedisRateLimiter redisRatelimiter = redisRateLimiterFactory.get(rateLimiterKey, timeUnit, permits);
        return rateCheck(redisRatelimiter, rateLimiterKey, response);
    }

    /**
     * Evaluates a rule's SpEL expression against the request's cookies and headers, which are
     * exposed as the variables {@code #Cookies} and {@code #Headers}.
     *
     * <p>NOTE(review): the expression runs in a {@link StandardEvaluationContext}, i.e. with the
     * full SpEL feature set. It must only ever come from trusted configuration.
     *
     * @param baseExp non-empty SpEL expression text
     * @param request request whose cookies and headers are made available to the expression
     * @return the expression result as a string, or {@code ""} when it evaluates to {@code null}
     */
    private String eval(String baseExp, HttpServletRequest request) {
        StandardEvaluationContext context = new StandardEvaluationContext();
        mountCookies(request, context);
        mountHeaders(request, context);
        Expression expression = expressionCache.get(baseExp);
        if (expression == null) {
            // A concurrent duplicate parse is harmless; the last equivalent instance wins.
            expression = EXPRESSION_PARSER.parseExpression(baseExp);
            expressionCache.put(baseExp, expression);
        }
        String baseVal = expression.getValue(context, String.class);
        return baseVal == null ? "" : baseVal;
    }

    /**
     * Registers the request cookies (name to value) as the {@code Cookies} variable. An empty map
     * is registered when the request has no cookies.
     */
    private void mountCookies(HttpServletRequest request, StandardEvaluationContext context) {
        Map<String, String> cookieMap = new HashMap<>();
        Cookie[] cookies = request.getCookies();
        if (cookies != null) {
            for (Cookie cookie : cookies) {
                cookieMap.put(cookie.getName(), cookie.getValue());
            }
        }
        context.setVariable("Cookies", cookieMap);
    }

    /**
     * Registers the request headers (name to first value) as the {@code Headers} variable. An empty
     * map is registered when the container exposes no header names.
     *
     * <p>NOTE(review): keys are stored exactly as the container reports them, so lookups from an
     * expression are case-sensitive.
     */
    private void mountHeaders(HttpServletRequest request, StandardEvaluationContext context) {
        Map<String, String> headerMap = new HashMap<>();
        Enumeration<String> headerNames = request.getHeaderNames();
        if (headerNames != null) {
            while (headerNames.hasMoreElements()) {
                String headerName = headerNames.nextElement();
                headerMap.put(headerName, request.getHeader(headerName));
            }
        }
        context.setVariable("Headers", headerMap);
    }

    /**
     * Tries to acquire a permit and, on failure, writes the rejection to the response.
     *
     * @param redisRatelimiter limiter to acquire from
     * @param keyPrefix        key passed to the limiter's {@code acquire} call
     * @param response         response that receives the 403 status and message on rejection
     * @return {@code true} when a permit was acquired, {@code false} otherwise
     * @throws Exception if the limiter or the response writer fails
     */
    private boolean rateCheck(RedisRateLimiter redisRatelimiter, String keyPrefix, HttpServletResponse response)
            throws Exception {
        if (!redisRatelimiter.acquire(keyPrefix)) {
            // NOTE(review): 429 (Too Many Requests) is the conventional status for rate limiting;
            // 403 is kept because existing clients may depend on it.
            response.setStatus(HttpStatus.FORBIDDEN.value());
            response.getWriter().print("Access denied because of exceeding access rate");
            return false;
        }
        return true;
    }


    /** No post-processing is required. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
                           ModelAndView modelAndView) throws Exception {
    }

    /** No completion handling is required. */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
    }

}

